import torch
import torch.nn as nn
from knowledge_tracing.args import ARGS
from knowledge_tracing.network.util_network import get_constraint_losses


class DKVMN(nn.Module):
    """
    Dynamic Key-Value Memory Network (DKVMN) for knowledge tracing, an extension
    of the Memory-Augmented Neural Network (MANN).

    A static key memory of shape (key_dim, concept_num) turns each question
    embedding into a softmax correlation weight over ``concept_num`` memory
    slots. A dynamic value memory of shape (batch_size, value_dim, concept_num)
    is updated (erase/add) from interaction embeddings and read with that weight
    to predict a per-step probability.

    Args:
        device: device on which the key memory and the initial value memory
            are created.
        encoder_features: iterable of ``(feature, dim)`` pairs. Each ``feature``
            must expose ``name`` and ``embed_layer(dim)``. Features named
            ``'item_idx'`` (its dim becomes key_dim) and ``'interaction_idx'``
            (its dim becomes value_dim) are required.
        summary_dim: output size of the summary layer.
        concept_num: number of memory slots.

    Raises:
        ValueError: if ``'item_idx'`` or ``'interaction_idx'`` is missing from
            ``encoder_features``.
    """

    def __init__(self, device, encoder_features, summary_dim, concept_num):
        super().__init__()
        self.device = device
        self._summary_dim = summary_dim
        self._concept_num = concept_num

        # Materialize once: the features are consumed twice below, and a
        # generator would silently be empty on the second pass.
        encoder_features = list(encoder_features)

        # embedding layers, keyed by feature name
        self.encoder_embedding_layers = torch.nn.ModuleDict({
            feature.name: feature.embed_layer(dim)
            for feature, dim in encoder_features
        })
        feature_dims = {feature.name: dim for feature, dim in encoder_features}
        missing = [name for name in ('item_idx', 'interaction_idx') if name not in feature_dims]
        if missing:
            raise ValueError(
                f"DKVMN requires encoder features {missing}; got {sorted(feature_dims)}")
        self._key_dim = feature_dims['item_idx']  # question embedding size
        self._value_dim = feature_dims['interaction_idx']  # interaction embedding size

        # FC layers
        self._erase_layer = nn.Sequential(
            nn.Linear(in_features=self._value_dim,
                      out_features=self._value_dim),
            nn.Sigmoid()
        )
        self._add_layer = nn.Sequential(
            nn.Linear(in_features=self._value_dim,
                      out_features=self._value_dim),
            nn.Tanh()
        )
        self._summary_layer = nn.Sequential(
            nn.Linear(in_features=self._value_dim + self._key_dim,
                      out_features=summary_dim),
            nn.Tanh()
        )
        self._output_layer = nn.Sequential(
            nn.Linear(in_features=summary_dim, out_features=1),
            nn.Sigmoid()
        )

        # Key memory stored transposed, (key_dim, concept_num), so that
        # question_vector @ key_memory gives (batch_size, concept_num).
        self._key_memory = nn.Parameter(
            torch.randn((self._key_dim, self._concept_num)).to(self.device))
        # Initial value memory stored transposed, (value_dim, concept_num);
        # broadcast to every batch row at the start of each sequence.
        self._init_value_memory = nn.Parameter(
            torch.randn((self._value_dim, self._concept_num)).to(self.device))
        self._value_memory = None

        # Xavier initialization of every parameter with more than one dimension.
        # This overwrites the randn values above; those draws are kept so that
        # the RNG stream (and thus seeded initialization) is unchanged.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        self.augmentations = ARGS.augmentations

    def _shift_tensor(self, x):
        """
        Shift a (bsz, seq_len) tensor right by one step along the sequence.

        Position 0 is filled with 0 and the last element is dropped, so that
        position i holds the original value at i - 1.
        """
        bsz = x.shape[0]
        # Follow the input's device rather than the construction-time device,
        # so the model keeps working after ``model.to(...)``.
        pad = torch.zeros([bsz, 1], dtype=torch.long, device=x.device)
        return torch.cat((pad, x[:, :-1]), dim=-1)

    def _compute_correlation_weight(self, question_vector):
        """
        Compute the correlation weight of a question with the key memory.

        Args:
            question_vector: tensor of shape (batch_size, key_dim)
        Return:
            correlation_weight: tensor of shape (batch_size, concept_num);
                each row is a softmax distribution over memory slots.
        """
        return torch.matmul(question_vector, self._key_memory).softmax(dim=-1)

    def _read(self, question_vector, value_memory, correlation_weight=None):
        """
        Read process: weighted sum of the value-memory slots.

        Args:
            question_vector: tensor of shape (batch_size, key_dim)
            value_memory: tensor of shape (batch_size, value_dim, concept_num)
            correlation_weight: optional precomputed result of
                ``_compute_correlation_weight(question_vector)``.
        Return:
            read_content: tensor of shape (batch_size, value_dim)
        """
        if correlation_weight is None:
            correlation_weight = self._compute_correlation_weight(question_vector)
        return torch.matmul(value_memory, correlation_weight.unsqueeze(-1)).squeeze(-1)

    def _write(self, interaction_id, question_vector, value_memory, correlation_weight=None):
        """
        Write process: erase then add the interaction into the value memory.

        Args:
            interaction_id: integer tensor of shape (batch_size)
            question_vector: tensor of shape (batch_size, key_dim), used to
                address the memory slots
            value_memory: tensor of shape (batch_size, value_dim, concept_num)
            correlation_weight: optional precomputed result of
                ``_compute_correlation_weight(question_vector)``.
        Return:
            new value memory, same shape as ``value_memory`` (not modified in place).
        """
        interaction_vector = self.encoder_embedding_layers['interaction_idx'](interaction_id)

        e = self._erase_layer(interaction_vector)  # erase vector, (batch_size, value_dim)
        a = self._add_layer(interaction_vector)  # add vector, (batch_size, value_dim)

        if correlation_weight is None:
            correlation_weight = self._compute_correlation_weight(question_vector)
        w = correlation_weight.unsqueeze(1)  # (batch_size, 1, concept_num)

        # Per-sample outer products -> (batch_size, value_dim, concept_num)
        erase = e.unsqueeze(-1) * w
        add = a.unsqueeze(-1) * w
        return value_memory * (1 - erase) + add

    def _initial_value_memory(self, batch_size):
        """
        Return the initial value memory repeated over the batch.

        The result is a (batch_size, value_dim, concept_num) view; it is never
        modified in place, because ``_write`` always builds a new tensor.
        """
        return self._init_value_memory.unsqueeze(0).expand(batch_size, -1, -1)

    def _predict_sequence(self, question_ids, interaction_ids, seq_size):
        """
        Run the memory network over one sequence, starting from fresh memory.

        Args:
            question_ids: integer tensor of shape (batch_size, >= seq_size)
            interaction_ids: integer tensor of shape (batch_size, >= seq_size)
            seq_size: number of steps to process
        Return:
            tensor of shape (batch_size, seq_size, 1) with per-step outputs.
        """
        batch_size = interaction_ids.shape[0]
        value_memory = self._initial_value_memory(batch_size)
        # At step i only the interaction from step i - 1 is known (0 at i = 0).
        prev_interactions = self._shift_tensor(interaction_ids)
        question_layer = self.encoder_embedding_layers['item_idx']

        outputs = []
        for i in range(seq_size):
            question_vec = question_layer(question_ids[:, i])  # (batch_size, key_dim)
            weight = self._compute_correlation_weight(question_vec)  # (batch_size, concept_num)

            # NOTE(review): the interaction from step i - 1 is written with the
            # correlation weight of question i, not question i - 1. This may
            # differ from the reference DKVMN formulation; confirm it is intended.
            value_memory = self._write(prev_interactions[:, i], question_vec, value_memory, weight)

            read_content = self._read(question_vec, value_memory, weight)  # (batch_size, value_dim)
            summary_vector = self._summary_layer(
                torch.cat((read_content, question_vec), dim=-1))  # (batch_size, summary_dim)
            outputs.append(self._output_layer(summary_vector))  # (batch_size, 1)

        return torch.cat(outputs, dim=1).unsqueeze(-1)  # (batch_size, seq_size, 1)

    def forward(self, data):
        """
        Args:
            data: A dictionary of dictionaries of tensors. Keys ('ori', 'rep', 'ins', 'del')
                indicate whether the data is the original or an augmented version.
                Each inner dict needs 'item_idx' and 'interaction_idx'.
        Return:
            (last_output, aug_losses):
                last_output maps each processed key to a (batch_size, seq_size, 1)
                tensor. Outside training, only 'ori' is processed.
                aug_losses is the result of ``get_constraint_losses`` in training
                mode when it is non-empty, otherwise None.
        """
        last_output = {}
        for aug, aug_data in data.items():
            # Augmented views are only used for the training-time constraint losses.
            if aug != 'ori' and not self.training:
                continue
            # Each view gets its own fresh value memory, so views do not leak
            # state into each other and the result does not depend on dict order.
            last_output[aug] = self._predict_sequence(
                aug_data['item_idx'], aug_data['interaction_idx'], ARGS.seq_size)

        aug_losses = get_constraint_losses(data, last_output) if self.training else {}
        if len(aug_losses) == 0:
            return last_output, None
        return last_output, aug_losses
